package main

import (
	"encoding/csv"
	"fmt"
	"io"
	"math"
	"os"
	"reflect"
	"sort"
	"strings"
)

// main trains a decision tree on the first rows of titanic.csv, prints the
// tree, and reports the classification accuracy on the remaining rows.
func main() {
	trainDataSet, testDataSet, features := loadDataSet(810)
	if len(trainDataSet) == 0 {
		// loadDataSet has already printed the cause of any failure; building a
		// tree from zero rows would panic inside createTree.
		fmt.Fprintln(os.Stderr, "no training data loaded")
		os.Exit(1)
	}

	// createTree's third argument does not influence the tree it returns.
	var remainLabels []string

	tree := createTree(trainDataSet, features, remainLabels)

	fmt.Println(tree)

	if len(testDataSet) == 0 {
		// Avoid printing NaN from 0/0.
		fmt.Println("test set is empty; accuracy not computed")
		return
	}

	correctNum := 0
	for _, row := range testDataSet {
		// The last column is the true label; the rest are feature values.
		want := row[len(row)-1]
		if classify(tree, features, row[:len(row)-1]) == want {
			correctNum++
		}
	}

	rate := float64(correctNum) / float64(len(testDataSet)) * 100

	// Prints "test-set accuracy: NN.NN%".
	fmt.Printf("测试集正确率：%.2f%%\n", rate)
}

// loadDataSet reads titanic.csv from the current working directory and returns
// (trainDataSet, testDataSet, features).
//
// Every returned row is built from CSV columns 4, 11 and 1, in that order: two
// feature values followed by the class label. A blank value in column 11 is
// replaced with "S". features holds the header names of columns 4 and 11.
// NOTE(review): in the usual Kaggle titanic.csv these columns are Sex,
// Embarked and Survived — confirm against the actual file.
//
// The first trainScale data rows (header excluded) go to the training set and
// all remaining rows to the test set. On any error the error is printed and
// three nil slices are returned.
func loadDataSet(trainScale int) ([][]string, [][]string, []string) {
	// 0-based CSV column positions.
	const (
		labelCol    = 1
		featureColA = 4
		featureColB = 11
	)

	file, err := os.Open("titanic.csv")
	if err != nil {
		fmt.Println("Error:", err)
		return nil, nil, nil
	}
	// The file is only read, so a Close error carries nothing actionable.
	defer file.Close()

	reader := csv.NewReader(file)

	header, err := reader.Read()
	if err != nil {
		// Includes io.EOF for an empty file, which previously caused a panic.
		fmt.Println("Error:", err)
		return nil, nil, nil
	}
	if len(header) <= featureColB {
		fmt.Println("Error:", "titanic.csv header has", len(header), "columns; need at least", featureColB+1)
		return nil, nil, nil
	}
	features := []string{header[featureColA], header[featureColB]}

	var trainDataSet, testDataSet [][]string
	// csv.Reader requires every record to have as many fields as the header,
	// so the column indexes below are in range once the header check passed.
	for rowNum := 0; ; rowNum++ {
		record, err := reader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			fmt.Println("Error:", err)
			return nil, nil, nil
		}

		second := record[featureColB]
		if second == "" {
			// Fill missing values with a fixed default.
			second = "S"
		}
		row := []string{record[featureColA], second, record[labelCol]}

		// rowNum counts data rows from 0, so exactly trainScale rows train.
		if rowNum < trainScale {
			trainDataSet = append(trainDataSet, row)
		} else {
			testDataSet = append(testDataSet, row)
		}
	}

	return trainDataSet, testDataSet, features
}

// calcEnt returns the Shannon entropy, in bits, of the class labels found in
// the last column of each row of data. An empty data set has entropy 0.
func calcEnt(data [][]string) float64 {
	total := float64(len(data))

	// Frequency of each class label; missing keys start at zero.
	freq := map[string]int{}
	for _, row := range data {
		freq[row[len(row)-1]]++
	}

	// H = -sum(p * log2 p) over the observed labels.
	var h float64
	for _, c := range freq {
		p := float64(c) / total
		h -= p * math.Log2(p)
	}
	return h
}

// splitDataSet returns the rows of dataSet whose column axis equals value,
// each with that column removed. Result rows are freshly allocated, so neither
// dataSet nor its rows are modified. Returns nil when no row matches.
func splitDataSet(dataSet [][]string, axis int, value string) [][]string {
	var subset [][]string
	for _, row := range dataSet {
		if strings.Compare(row[axis], value) != 0 {
			continue
		}
		// Build a new row without the split column; never alias the input row.
		reduced := make([]string, 0, len(row)-1)
		reduced = append(reduced, row[:axis]...)
		reduced = append(reduced, row[axis+1:]...)
		subset = append(subset, reduced)
	}
	return subset
}

// chooseBestFeature returns the index of the feature column with the largest
// information gain (base entropy minus the weighted entropy of the split) for
// dataSet, whose last column holds the class label. Each feature's gain is
// printed to stdout. Ties keep the earliest column.
//
// It returns -1 only when dataSet is empty or has no feature columns; otherwise
// a valid column index is always returned, even if no feature reduces entropy.
func chooseBestFeature(dataSet [][]string) int {
	if len(dataSet) == 0 {
		return -1
	}
	// Number of feature columns (everything except the trailing label).
	featureNum := len(dataSet[0]) - 1
	baseEntropy := calcEnt(dataSet)
	total := float64(len(dataSet))

	// Start below any reachable gain so the first feature is always a
	// candidate. Starting at 0.0 left the result at -1 whenever every gain was
	// zero (or slightly negative from rounding), and callers then indexed
	// labels[-1].
	bestInfoGain := math.Inf(-1)
	bestFeatureIdx := -1

	for i := 0; i < featureNum; i++ {
		// All values of column i.
		featList := make([]string, 0, len(dataSet))
		for _, row := range dataSet {
			featList = append(featList, row[i])
		}
		// Conditional entropy of the label given column i.
		newEntropy := 0.0
		for _, v := range distinct(featList) {
			subDataSet := splitDataSet(dataSet, i, v.(string))
			prob := float64(len(subDataSet)) / total
			newEntropy += prob * calcEnt(subDataSet)
		}
		infoGain := baseEntropy - newEntropy
		// Prints "gain of feature i is X".
		fmt.Printf("特征%d的增益为%.3f\n", i, infoGain)
		// Strict > keeps the earliest feature on ties.
		if infoGain > bestInfoGain {
			bestInfoGain = infoGain
			bestFeatureIdx = i
		}
	}

	return bestFeatureIdx
}

// vote returns the most frequent value in classList (majority vote). Ties are
// broken deterministically in favour of the lexicographically smallest value.
// An empty classList yields "".
func vote(classList []string) string {
	counts := make(map[string]int)
	for _, c := range classList {
		counts[c]++
	}

	best, bestCount := "", 0
	for c, n := range counts {
		// Map iteration order is random; the secondary key keeps ties stable
		// instead of depending on iteration order and an unstable sort.
		if n > bestCount || (n == bestCount && c < best) {
			best, bestCount = c, n
		}
	}
	return best
}

// createTree builds an ID3 decision tree from dataSet. Each row holds feature
// values followed by the class label in its last column, and labels[i] names
// column i.
//
// The result is a nested map: a leaf is {classLabel: nil}; an internal node is
// {featureLabel: {featureValue: subtree, ...}}. An empty dataSet yields an
// empty map.
//
// remainFeatures has the chosen split label appended before it is passed down
// the recursion; it does not influence the returned tree.
func createTree(dataSet [][]string, labels []string, remainFeatures []string) map[string]interface{} {
	if len(dataSet) == 0 {
		// Nothing to learn from; previously this panicked on classList[0].
		return map[string]interface{}{}
	}

	// Class label of every row.
	classList := make([]string, 0, len(dataSet))
	for _, row := range dataSet {
		classList = append(classList, row[len(row)-1])
	}
	// All rows share one class: stop splitting.
	if count(classList, classList[0]) == len(classList) {
		return map[string]interface{}{classList[0]: nil}
	}
	// Only the label column remains: fall back to the majority class.
	if len(dataSet[0]) == 1 {
		return map[string]interface{}{vote(classList): nil}
	}

	bestFeatIdx := chooseBestFeature(dataSet)
	if bestFeatIdx < 0 {
		// No usable split was found; indexing labels[-1] would panic.
		return map[string]interface{}{vote(classList): nil}
	}
	bestFeatLabel := labels[bestFeatIdx]
	remainFeatures = append(remainFeatures, bestFeatLabel)

	// Child data sets lose the chosen column, so their labels must too. Build a
	// fresh slice so the caller's labels are never modified.
	subLabels := make([]string, 0, len(labels)-1)
	subLabels = append(subLabels, labels[:bestFeatIdx]...)
	subLabels = append(subLabels, labels[bestFeatIdx+1:]...)

	// Values that occur in the chosen column.
	featValues := make([]string, 0, len(dataSet))
	for _, row := range dataSet {
		featValues = append(featValues, row[bestFeatIdx])
	}

	// One branch per distinct value, each trained on the matching rows.
	branches := make(map[string]interface{})
	for _, v := range distinct(featValues) {
		value := v.(string)
		branches[value] = createTree(splitDataSet(dataSet, bestFeatIdx, value), subLabels, remainFeatures)
	}

	return map[string]interface{}{bestFeatLabel: branches}
}

// classify walks tree (shaped as createTree builds it) for the sample testVec
// and returns the predicted class label. features names the columns of
// testVec and must be the full feature list, because feature names stored in
// the tree are looked up in it.
//
// It returns "" when the tree is empty or malformed, when a feature name is not
// a valid column of testVec, or when testVec holds a value for which the
// current node has no branch (a value never seen there during training).
func classify(tree map[string]interface{}, features []string, testVec []string) string {
	// A leaf is {classLabel: nil}; otherwise the key names the tested feature.
	var nodeLabel string
	for k, v := range tree {
		if v == nil {
			return k
		}
		nodeLabel = k
	}

	branches, ok := tree[nodeLabel].(map[string]interface{})
	if !ok {
		// Empty or malformed tree; previously an unchecked type assertion panicked.
		return ""
	}

	featIdx := index(features, nodeLabel)
	if featIdx < 0 || featIdx >= len(testVec) {
		return ""
	}
	value := testVec[featIdx]

	// Direct lookup instead of scanning every branch key.
	child, found := branches[value]
	if !found {
		return ""
	}
	if child == nil {
		// Kept for compatibility with the original: a nil branch yields the
		// feature value itself. createTree never stores nil branches.
		return value
	}
	subtree, ok := child.(map[string]interface{})
	if !ok {
		return ""
	}
	return classify(subtree, features, testVec)
}

// index returns the position of the first element of target equal to value,
// or -1 if value is absent. Returning 0 for "absent" (the old behavior) was
// indistinguishable from a match at position 0.
func index(target []string, value string) int {
	for i, s := range target {
		if s == value {
			return i
		}
	}
	return -1
}

// count reports how many elements of target are exactly equal to value.
func count(target []string, value string) int {
	n := 0
	for i := range target {
		if target[i] == value {
			n++
		}
	}
	return n
}

// duplicate returns the elements of a, boxed as interface{}, with every run of
// consecutive equal elements (per reflect.DeepEqual) collapsed to one element.
// On sorted input this removes all duplicates. a is expected to be a slice,
// array or string; other kinds make reflect panic. Empty input yields nil.
func duplicate(a interface{}) (ret []interface{}) {
	seq := reflect.ValueOf(a)
	var prev interface{}
	for i, n := 0, seq.Len(); i < n; i++ {
		cur := seq.Index(i).Interface()
		// Keep the element unless it repeats the one just before it.
		if i == 0 || !reflect.DeepEqual(prev, cur) {
			ret = append(ret, cur)
		}
		prev = cur
	}
	return ret
}

// distinct returns the unique values of val in ascending order. Each returned
// element is a string boxed as interface{}. val itself is left unmodified.
func distinct(val []string) []interface{} {
	// Sort a copy: sorting val in place silently reordered the caller's slice.
	sorted := make([]string, len(val))
	copy(sorted, val)
	sort.Strings(sorted)
	// On sorted input, collapsing adjacent repeats leaves each value once.
	return duplicate(sorted)
}
